package com.gitee.cliveyuan.tools.codec;

import com.gitee.cliveyuan.tools.Assert;
import com.gitee.cliveyuan.tools.bean.rsa.GenerateKeyPairReq;
import com.gitee.cliveyuan.tools.bean.rsa.RSAKeyPair;
import com.gitee.cliveyuan.tools.bean.rsa.RSARequest;
import com.gitee.cliveyuan.tools.bean.rsa.RSASignReq;
import com.gitee.cliveyuan.tools.bean.rsa.RSAVerifySignReq;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.codec.binary.Base64;
import org.bouncycastle.jce.provider.BouncyCastleProvider;

import javax.crypto.Cipher;
import java.nio.charset.StandardCharsets;
import java.security.Key;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.Security;
import java.security.Signature;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;
import java.util.Objects;

/**
 * RSA 工具
 * <ol>
 *     <li>密钥生成</li>
 *     <li>加密/解密</li>
 *     <li>签名/验签</li>
 * </ol>
 */
@Slf4j
public class RSATools {

    private static final int DEFAULT_KEY_SIZE = 2048;
    private static final String DEFAULT_ALGORITHM = "RSA";
    private static final String DEFAULT_CHARSET = StandardCharsets.UTF_8.displayName();
    private static final String DEFAULT_SIGNATURE_ALGORITHM = "SHA256withRSA";

    /**
     * 生成RSA公钥私钥
     */
    public static RSAKeyPair generateKeyPair() {
        try {
            return generateKeyPair(GenerateKeyPairReq.builder().build());
        } catch (NoSuchAlgorithmException e) {
            log.error("verifySign Exception", e);
            throw new IllegalArgumentException("NoSuchAlgorithmException");
        }

    }

    /**
     * 生成RSA公钥私钥
     *
     * @param generateKeyPairReq generateKeyPairReq
     * @return
     * @throws NoSuchAlgorithmException
     */
    public static RSAKeyPair generateKeyPair(GenerateKeyPairReq generateKeyPairReq) throws NoSuchAlgorithmException {
        Assert.notNull(generateKeyPairReq);
        KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance(Objects.nonNull(generateKeyPairReq.getKeyAlgorithm()) ?
                generateKeyPairReq.getKeyAlgorithm() : DEFAULT_ALGORITHM);
        keyPairGenerator.initialize(Objects.nonNull(generateKeyPairReq.getKeySize()) ?
                generateKeyPairReq.getKeySize() : DEFAULT_KEY_SIZE);
        KeyPair keyPair = keyPairGenerator.generateKeyPair();
        PublicKey pubKey = keyPair.getPublic();
        PrivateKey privateKey = keyPair.getPrivate();
        return RSAKeyPair.builder()
                .publicKey(new String(Base64.encodeBase64(pubKey.getEncoded(), false)))
                .privateKey(new String(Base64.encodeBase64(privateKey.getEncoded(), false)))
                .build();
    }

    /**
     * 验签
     *
     * @param data            需要验证签名的数据字符串
     * @param sign            签名字符串(长度：1024-->128 2048-->256)
     * @param publicKeyString RSA公钥
     * @return boolean true : 验证签名成功 false: 验证签名失败
     */
    public static boolean verifySign(String data, String sign, String publicKeyString) {
        return verifySign(RSAVerifySignReq.builder()
                .data(data)
                .sign(sign)
                .publicKeyString(publicKeyString)
                .build());
    }

    /**
     * 验签
     *
     * @param rsaVerifySignReq rsaVerifySignReq
     * @return
     */
    public static boolean verifySign(RSAVerifySignReq rsaVerifySignReq) {
        Assert.notNull(rsaVerifySignReq);
        Assert.notNull(rsaVerifySignReq.getData(), "data is required");
        Assert.notNull(rsaVerifySignReq.getSign(), "sign is required");
        Assert.notNull(rsaVerifySignReq.getPublicKeyString(), "publicKeyString is required");
        try {
            KeyFactory keyFactory = KeyFactory.getInstance(Objects.nonNull(rsaVerifySignReq.getKeyAlgorithm()) ?
                    rsaVerifySignReq.getKeyAlgorithm() : DEFAULT_ALGORITHM);
            byte[] encodedKey = Base64.decodeBase64(rsaVerifySignReq.getPublicKeyString().getBytes());
            PublicKey publicKey = keyFactory.generatePublic(new X509EncodedKeySpec(encodedKey));
            Signature signet = Signature.getInstance(Objects.nonNull(rsaVerifySignReq.getSignatureAlgorithm()) ?
                    rsaVerifySignReq.getSignatureAlgorithm() : DEFAULT_SIGNATURE_ALGORITHM);
            signet.initVerify(publicKey);
            signet.update(rsaVerifySignReq.getData().getBytes(Objects.nonNull(rsaVerifySignReq.getCharset()) ?
                    rsaVerifySignReq.getCharset() : DEFAULT_CHARSET));
            return signet.verify(Base64.decodeBase64(rsaVerifySignReq.getSign().getBytes()));
        } catch (Exception e) {
            log.error("verifySign Exception", e);
        }
        return false;
    }


    /**
     * 用私钥对信息生成数字签名
     *
     * @param data             数据
     * @param privateKeyString 私钥
     * @return
     */
    public static String sign(String data, String privateKeyString) {
        return sign(RSASignReq.builder().data(data).privateKeyString(privateKeyString).build());
    }

    /**
     * 用私钥对信息生成数字签名
     *
     * @param rsaSignReq rsaSignReq
     * @return
     */
    public static String sign(RSASignReq rsaSignReq) {
        Assert.notNull(rsaSignReq);
        Assert.notNull(rsaSignReq.getData(), "data is required");
        Assert.notNull(rsaSignReq.getPrivateKeyString(), "privateKeyString is required");
        try {
            byte[] keyBytes = Base64.decodeBase64(rsaSignReq.getPrivateKeyString().getBytes());
            // 构造PKCS8EncodedKeySpec对象
            PKCS8EncodedKeySpec pkcs8EncodedKeySpec = new PKCS8EncodedKeySpec(keyBytes);
            // 指定加密算法
            KeyFactory keyFactory = KeyFactory.getInstance(Objects.nonNull(rsaSignReq.getKeyAlgorithm()) ?
                    rsaSignReq.getKeyAlgorithm() : DEFAULT_ALGORITHM);
            // 取私钥匙对象
            PrivateKey privateKey2 = keyFactory.generatePrivate(pkcs8EncodedKeySpec);
            // 用私钥对信息生成数字签名
            Signature signature = Signature.getInstance(Objects.nonNull(rsaSignReq.getSignatureAlgorithm()) ?
                    rsaSignReq.getSignatureAlgorithm() : DEFAULT_SIGNATURE_ALGORITHM);
            signature.initSign(privateKey2);
            signature.update(rsaSignReq.getData().getBytes(Objects.nonNull(rsaSignReq.getCharset()) ?
                    rsaSignReq.getCharset() : DEFAULT_CHARSET));
            return new String(Base64.encodeBase64(signature.sign()));
        } catch (Exception e) {
            log.error("sign Exception", e);
        }
        return null;
    }

    /**
     * 用公钥加密
     *
     * @return
     */
    public static byte[] encrypt(RSARequest rsaRequest) {
        Assert.notNull(rsaRequest);
        Assert.notNull(rsaRequest.getData(), "data is required");
        Assert.notNull(rsaRequest.getKeyString(), "publicKeyString is required");
        try {
            // 对公钥解密
            byte[] keyBytes = decodeBase64(rsaRequest.getKeyString());
            // 取得公钥
            X509EncodedKeySpec x509KeySpec = new X509EncodedKeySpec(keyBytes);
            String algorithm = Objects.nonNull(rsaRequest.getKeyAlgorithm()) ?
                    rsaRequest.getKeyAlgorithm() : DEFAULT_ALGORITHM;

            KeyFactory keyFactory;
            if (Objects.nonNull(rsaRequest.getProvider())) {
                keyFactory = KeyFactory.getInstance(algorithm, rsaRequest.getProvider());
            } else {
                keyFactory = KeyFactory.getInstance(algorithm);
            }
            Key publicKey = keyFactory.generatePublic(x509KeySpec);
            // 对数据加密
            String cipherAlgorithm = keyFactory.getAlgorithm();
            Cipher cipher;
            if (Objects.nonNull(rsaRequest.getProvider())) {
                cipher = Cipher.getInstance(cipherAlgorithm, rsaRequest.getProvider());
            } else {
                cipher = Cipher.getInstance(cipherAlgorithm);
            }
            cipher.init(Cipher.ENCRYPT_MODE, publicKey);
            return cipher.doFinal(rsaRequest.getData());
        } catch (Exception e) {
            log.error("encrypt Exception", e);
        }
        return null;
    }

    /**
     * 加密
     *
     * @param data         数据
     * @param publicKeyStr 公钥
     * @return
     */
    public static String encrypt(String data, String publicKeyStr) {
        try {
            return Base64.encodeBase64String(encrypt(RSARequest.builder()
                    .data(data.getBytes())
                    .keyString(publicKeyStr)
                    .build()));
        } catch (Exception e) {
            log.error("encrypt Exception", e);
        }
        return null;
    }

    /**
     * 用私钥解密
     *
     * @return
     */
    public static byte[] decrypt(RSARequest rsaRequest) {
        Security.addProvider(new BouncyCastleProvider());
        Assert.notNull(rsaRequest);
        Assert.notNull(rsaRequest.getData(), "data is required");
        Assert.notNull(rsaRequest.getKeyString(), "privateKeyString is required");
        try {
            // 对密钥解密
            byte[] keyBytes = decodeBase64(rsaRequest.getKeyString());
            // 取得私钥
            PKCS8EncodedKeySpec pkcs8KeySpec = new PKCS8EncodedKeySpec(keyBytes);
            KeyFactory keyFactory;
            String algorithm = Objects.nonNull(rsaRequest.getKeyAlgorithm()) ?
                    rsaRequest.getKeyAlgorithm() : DEFAULT_ALGORITHM;
            if (Objects.nonNull(rsaRequest.getProvider())) {
                keyFactory = KeyFactory.getInstance(algorithm, rsaRequest.getProvider());
            } else {
                keyFactory = KeyFactory.getInstance(algorithm);
            }
            Key privateKey = keyFactory.generatePrivate(pkcs8KeySpec);
            // 对数据解密
            String cipherAlgorithm = rsaRequest.getCipherAlgorithm();
            if (Objects.isNull(cipherAlgorithm)) {
                cipherAlgorithm = keyFactory.getAlgorithm();
            }
            Cipher cipher;
            if (Objects.nonNull(rsaRequest.getProvider())) {
                cipher = Cipher.getInstance(cipherAlgorithm, rsaRequest.getProvider());
            } else {
                cipher = Cipher.getInstance(cipherAlgorithm);
            }
            cipher.init(Cipher.DECRYPT_MODE, privateKey);
            return cipher.doFinal(rsaRequest.getData());
        } catch (Exception e) {
            log.error("decrypt Exception", e);
        }
        return null;
    }

    /**
     * 解密
     *
     * @param data          数据
     * @param privateKeyStr 私钥
     * @return
     */
    public static String decrypt(String data, String privateKeyStr) {
        try {
            byte[] decrypt = decrypt(RSARequest.builder()
                    .data(Base64.decodeBase64(data))
                    .keyString(privateKeyStr)
                    .build());
            if (Objects.isNull(decrypt)) {
                return null;
            }
            return new String(decrypt);
        } catch (Exception e) {
            log.error("decrypt Exception", e);
        }
        return null;
    }

    /**
     * BASE64解密
     *
     * @param base64String binaryData
     * @return
     */
    private static byte[] decodeBase64(String base64String) {
        return Base64.decodeBase64(base64String);
    }

    /**
     * BASE64加密
     *
     * @param binaryData binaryData
     * @return
     */
    private static String encodeBase64(byte[] binaryData) {
        return new String(Base64.encodeBase64(binaryData));
    }
}
